# Metrics for track 4 - point cloud classification

import argparse
from pathlib import Path
import numpy as np
import re

# Classification labels being tracked, keyed by label code. Add new labels to LABELS_OBJ in any
# order: LABELS is derived by sorting its keys.
NULL_CLASS = 0  # not classified in ground_truth
LABELS_OBJ = {2: "Ground", 5: "High Vegetation", 6: "Building", 9: "Water", 17: "Bridge Deck"}
LABELS = sorted(LABELS_OBJ)
# Maps each label code to its row/column position in the confusion matrix.
LABEL_INDEXES = {label: index for index, label in enumerate(LABELS)}


class ClassificationScore:
    """
        Holds the true positive, false negative and false positive counts of one classification
        label and derives its intersection over union (IOU) from them
    """

    def __init__(self, true_positive=0, false_negative=0, false_positive=0):
        # All counts default to zero so an empty score can be created and accumulated into via add()
        self.true_positive, self.false_negative, self.false_positive = (
            true_positive, false_negative, false_positive)

    def get_iou(self):
        """
        IOU = TP / (TP + FN + FP); returns 1 when all three counts are zero.
        """
        union = self.true_positive + self.false_negative + self.false_positive
        return 1 if union == 0 else self.true_positive / union

    def add(self, other):
        """
        Adds the counts of another score into this one, in place.
        """
        self.true_positive = self.true_positive + other.true_positive
        self.false_positive = self.false_positive + other.false_positive
        self.false_negative = self.false_negative + other.false_negative


def generate_confusion_matrix(zipped):
    """
    Creates a confusion matrix from two aligned data sets
    :param zipped: an iterable of (truth, prediction) label pairs, e.g. zip(truth, prediction)
    :return: an M x (M+1) matrix of counts, where M = len(LABELS). Rows are indexed by truth label and
             columns by prediction label (positions from LABEL_INDEXES); the extra last column counts
             predictions whose label is not tracked. Truth points whose label is NULL_CLASS or untracked
             are skipped.
    """
    dim = len(LABELS)
    matrix = np.zeros((dim, dim + 1), np.uint)
    # Reported labels are tracked separately for truth and prediction so that an invalid prediction
    # label is still reported even if the same value was already reported as an invalid truth label.
    invalid_truth_labels = set()
    invalid_prediction_labels = set()
    unlabeled_count = 0
    untracked_count = 0
    for truth_val, prediction_val in zipped:
        truth_index = LABEL_INDEXES.get(truth_val)
        if truth_index is None:
            if truth_val == NULL_CLASS:
                unlabeled_count += 1  # This was junk, can't reliably score it
                continue
            # a label is found, but not tracked by this software: log it once and skip the point
            untracked_count += 1
            if truth_val not in invalid_truth_labels:
                print("Invalid truth LABEL: {}".format(truth_val))
                invalid_truth_labels.add(truth_val)
            continue
        prediction_index = LABEL_INDEXES.get(prediction_val)
        if prediction_index is None:
            if prediction_val not in invalid_prediction_labels:
                invalid_prediction_labels.add(prediction_val)
                print("Invalid prediction LABEL: {}".format(prediction_val))
            prediction_index = -1  # last column collects untracked predictions
        matrix[truth_index][prediction_index] += 1
    if unlabeled_count:
        print("Skipped {} unlabeled truth points".format(unlabeled_count))
    if untracked_count:
        print("Skipped {} truth points with untracked labels".format(untracked_count))
    return matrix


def print_matrix(mat):
    """
    Prints the confusion matrix to terminal in a human readable manner.
    Rows are truth labels ("T"); columns are predicted labels ("P(label)") followed by a final
    'OTHER' column for predictions of untracked labels.
    :param mat: Confusion matrix generated by this software
    :return: nothing
    """
    all_labels = LABELS + ['OTHER']
    # The corner cell is padded to the same width as the row labels below so the columns line up.
    header = "{:<3}".format("T")
    for label in all_labels:
        header += "{:>9} ".format("P({})".format(label))
    print(header)
    print('-' * len(header))
    for label in LABELS:
        row_index = LABEL_INDEXES[label]
        row = "{:<3}".format(label)
        for label_inner in all_labels:
            # compare by value: 'is' on a string literal relies on interning (SyntaxWarning on 3.8+)
            column_index = -1 if label_inner == 'OTHER' else LABEL_INDEXES[label_inner]
            row += "{:>9}|".format(mat[row_index][column_index])
        print(row)
    print()


def get_overall_accuracy(matrix):
    """
    Computes the overall accuracy of predictions given a confusion matrix
    :param matrix: Confusion matrix for a prediction event (numpy array; rows are truth classes,
                   columns are predicted classes, so matrix[i][i] counts correct predictions of class i)
    :return: overall accuracy between 0 -> 1, where 1 equates to 100%; 1 if the matrix holds no points
    """
    count = matrix.sum()
    if count == 0:
        return 1
    # np.trace sums matrix[i][i] over the leading min(rows, cols) diagonal, so the extra 'OTHER'
    # column of an M x (M+1) matrix is excluded automatically.
    correct = np.trace(matrix)
    return correct / count


def get_mean_intersection_over_union(scores):
    """
    Unweighted mean of the per-class IOUs (MIOU)
    :param scores: a sequence of ClassificationScore objects
    :return: value in 0->1 where 1 implies perfect intersection; 0 when scores is empty
    """
    if len(scores) < 1:
        return 0
    total = sum(score.get_iou() for score in scores)
    return total / len(scores)


def score_predictions(matrix):
    """
    Calculates the number of true positives, false negatives and false positives for each class in a confusion matrix
    :param matrix: Confusion matrix for a prediction event; rows are indexed by truth class and columns by
                   predicted class (any extra trailing column counts predictions of untracked labels)
    :return: An array of ClassificationScores, one for each classification label (i.e. each matrix row)
    """
    per_class = []
    for class_index in range(matrix.shape[0]):
        true_positives = matrix[class_index][class_index]
        # false negatives are the points where the truth is the class, but the prediction is not (so sum across the
        # row, including any untracked-prediction column, and subtract off true positives)
        false_negatives = np.sum(matrix[class_index, :]) - true_positives
        # false positives are the points where the prediction is the class, but the truth is not (so sum down the
        # column and subtract off true positives)
        false_positives = np.sum(matrix[:, class_index]) - true_positives
        per_class.append(ClassificationScore(true_positives, false_negatives, false_positives))
    return per_class


def report_scores(confusion_matrix):
    """
    Prints a human readable report for a confusion matrix: the overall accuracy, the matrix itself,
    the mean IOU and the IOU of every tracked class.
    :param confusion_matrix: Confusion matrix generated by this software
    :return: nothing
    """
    accuracy_pct = get_overall_accuracy(confusion_matrix) * 100
    print("Confusion matrix with overall accuracy: {:.2f}%".format(accuracy_pct))
    print_matrix(confusion_matrix)

    class_scores = score_predictions(confusion_matrix)
    print("MIOU: {}".format(get_mean_intersection_over_union(class_scores)))
    for position in range(len(class_scores)):
        label = LABELS[position]
        class_iou = class_scores[position].get_iou()
        print("Class {:2d} ({:^17}), IOU: {:.4f}".format(label, LABELS_OBJ[label], class_iou))
    print()


def _read_label_file(path):
    """
    Reads one integer classification label per line from a text file
    :param path: Path (or str) of the file to read
    :return: list of ints, or None (after printing an error) if any line is not an integer
    """
    with open(str(path), 'r') as file:
        try:
            return [int(line) for line in file]
        except ValueError:
            print("Error reading {}".format(path))
            return None


def score_prediction_files(ground_truth_file, prediction_file):
    """
    Scores a single prediction file against its ground truth file and prints the results
    :param ground_truth_file: Ground truth classification file (one integer label per line)
    :param prediction_file: Classification prediction file, aligned line-by-line with the ground truth
    :return: The confusion matrix for this file pair; an all-zero matrix if either file could not be
             parsed or the files have different lengths
    """
    print("Scoring {} against {}:".format(prediction_file, ground_truth_file))

    # Default confusion matrix, returned unchanged on any error
    dim = len(LABELS)
    empty_matrix = np.zeros((dim, dim + 1), np.uint)

    # Each file reports its own name on a parse failure
    gt_data = _read_label_file(ground_truth_file)
    if gt_data is None:
        return empty_matrix
    pd_data = _read_label_file(prediction_file)
    if pd_data is None:
        return empty_matrix

    # Error check number of values
    if len(gt_data) != len(pd_data):
        print("Mismatched file lengths!")
        return empty_matrix

    # Matches line i to line i, creating an iterable of tuples (ground_truth[i], prediction[i])
    confusion_matrix = generate_confusion_matrix(zip(gt_data, pd_data))
    print("Scores for {} (truth: {}):".format(prediction_file, ground_truth_file))
    report_scores(confusion_matrix)
    return confusion_matrix


def directory_type(arg_string):
    """
    Allows arg parser to handle directories
    :param arg_string: A path, relative or absolute to a folder
    :return: A pathlib Path object to a directory.
    :raises argparse.ArgumentTypeError: if arg_string is not an existing directory
    """
    directory_path = Path(arg_string)
    # is_dir() is False for paths that do not exist, so no separate exists() check is needed
    if directory_path.is_dir():
        return directory_path
    # argparse reports an ArgumentTypeError raised by a type= callable as a usage error with this
    # message; ArgumentError needs the offending action as its first argument and cannot be built here.
    raise argparse.ArgumentTypeError("{} is not a valid directory.".format(arg_string))


def file_type(arg_string):
    """
    Allows arg parser to check against files
    :param arg_string: A path, relative or absolute to a file
    :return: A pathlib Path object to a file.
    :raises argparse.ArgumentTypeError: if arg_string is not an existing regular file
    """
    file_path = Path(arg_string)
    # is_file() is False for paths that do not exist, so no separate exists() check is needed
    if file_path.is_file():
        return file_path
    # argparse reports an ArgumentTypeError raised by a type= callable as a usage error with this
    # message; ArgumentError needs the offending action as its first argument and cannot be built here.
    raise argparse.ArgumentTypeError("{} is not a valid file.".format(arg_string))


def get_list_of_files(directory_path):
    """
    Lists the classification text files in a directory whose names start with a tile prefix
    such as 'ABC_123_'. Files with 'CLS' in their name are preferred; if there are none, every
    such '.txt' file is used instead.
    :param directory_path: pathlib Path of the directory to search
    :return: sorted list of Paths
    :raises ValueError: if no matching file is found
    """
    tile_pattern = re.compile(r'[A-Z]{3}_\d{3,}_.*')

    def matching(glob_pattern):
        # glob candidates whose file name starts with the tile prefix
        return [Path(candidate) for candidate in directory_path.glob(glob_pattern)
                if tile_pattern.match(candidate.name)]

    # The '*.txt' fallback is only evaluated when no CLS file was found
    found = matching('*CLS*.txt') or matching('*.txt')
    if not found:
        raise ValueError("Could not find classification files in {}".format(directory_path))
    return sorted(found)


def get_tile_name(file):
    """
    Extracts the tile name prefix (e.g. 'ABC_123') from a classification file name
    :param file: pathlib Path whose name starts with three capital letters, an underscore,
                 at least three digits and another underscore
    :return: the tile name, without the trailing underscore
    :raises ValueError: if the file name does not follow that pattern
    """
    match = re.match(r'([A-Z]{3}_\d{3,})_.*', file.name)
    if match is None:
        # Previously this surfaced as an opaque AttributeError on None
        raise ValueError("Could not determine tile name from {}".format(file))
    return match.group(1)


def match_file_pairs(A, B):
    """
    Pairs files from A with files from B that share the same tile name (see get_tile_name)
    :param A: a single file Path or a list of file Paths
    :param B: a single file Path or a list of file Paths
    :return: a single (a, b) tuple when A or B is a single file; a list with one (a, b) tuple per
             file in A when both are lists
    :raises ValueError: if a file in A has no file with the same tile name in B
    """
    if not isinstance(A, list):
        a_tile = get_tile_name(A)
        candidates = B if isinstance(B, list) else [B]
        # next() with a default returns None instead of leaking StopIteration when nothing matches
        b = next((b for b in candidates if get_tile_name(b) == a_tile), None)
        if b is None:
            raise ValueError("Could not match {}".format(A))
        return (A, b)
    if not isinstance(B, list):
        # Match the single B against A, then flip the pair back into (a, b) order
        return match_file_pairs(B, A)[::-1]
    return [match_file_pairs(a, B) for a in A]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Score point cloud classification predictions against ground truth labels.")
    parser.add_argument('-g', '--ground_truth_directory', type=directory_type,
                        help="directory of ground truth classification files")
    parser.add_argument('-t', '--ground_truth_file', type=file_type,
                        help="single ground truth classification file")
    parser.add_argument('-d', '--prediction_directory', type=directory_type,
                        help="directory of classification prediction files")
    parser.add_argument('-f', '--prediction_file', type=file_type,
                        help="single classification prediction file")
    args = parser.parse_args()

    # Get list of truth files (a single file and a directory may be combined)
    truth_files = []
    if args.ground_truth_file is not None:
        truth_files.append(args.ground_truth_file)
    if args.ground_truth_directory is not None:
        truth_files.extend(get_list_of_files(args.ground_truth_directory))
    if not truth_files:
        # parser.error prints the usage line and exits with status 2 instead of a traceback
        parser.error('No ground truth paths specified')

    # Get list of class prediction files
    prediction_files = []
    if args.prediction_file is not None:
        prediction_files.append(args.prediction_file)
    if args.prediction_directory is not None:
        prediction_files.extend(get_list_of_files(args.prediction_directory))
    if not prediction_files:
        parser.error('No prediction paths specified')

    # Match truth to prediction files by tile name
    file_pairs = match_file_pairs(truth_files, prediction_files)
    if not isinstance(file_pairs, list):
        file_pairs = [file_pairs]

    # Sum the per-file confusion matrices to score the whole data set
    confusion_matrix = np.zeros((len(LABELS), len(LABELS) + 1), np.uint)
    for file_pair in file_pairs:
        confusion_matrix += score_prediction_files(*file_pair)

    # Per-file scores are already printed; the aggregate only adds information for multiple files
    if len(file_pairs) > 1:
        print("----- OVERALL SCORES -----")
        report_scores(confusion_matrix)
